{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WikiQA splits, all loaded for the ranking task.\n",
    "# The dev and test splits are requested with filter=True; the training split keeps the loader default.\n",
    "load_wiki_qa = mz.datasets.wiki_qa.load_data\n",
    "\n",
    "train_pack = load_wiki_qa(\n",
    "    'train',\n",
    "    task='ranking',\n",
    ")\n",
    "valid_pack = load_wiki_qa(\n",
    "    'dev',\n",
    "    task='ranking',\n",
    "    filter=True,\n",
    ")\n",
    "predict_pack = load_wiki_qa(\n",
    "    'test',\n",
    "    task='ranking',\n",
    "    filter=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 7736.47it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 3956.36it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 743868.61it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 154893.95it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 82871.96it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 318920.69it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 666556.02it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 234249/234249 [00:00<00:00, 2372671.77it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 7708.43it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 3922.76it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 157160.24it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 157971.65it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 167186.51it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 502490.86it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 652521.17it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 98155.19it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 87504.41it/s]\n"
     ]
    }
   ],
   "source": [
    "# Basic text pipeline: fixed length 10 for text_left, 100 for text_right, stop-word removal on.\n",
    "preprocessor = mz.preprocessors.BasicPreprocessor(\n",
    "    fixed_length_left=10,\n",
    "    fixed_length_right=100,\n",
    "    remove_stop_words=True,\n",
    ")\n",
    "\n",
    "# Fit on the training split and transform it in the same call.\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 5830.20it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 3952.66it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 119864.90it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 83434.71it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 125164.57it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 98803.84it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 340615.36it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 25928.81it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 58033.74it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 7003.16it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 3970.78it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 171549.23it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 129822.39it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 182336.92it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 109224.27it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 361374.76it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 72436.79it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 13800.94it/s]\n"
     ]
    }
   ],
   "source": [
    "# Apply the already-fitted preprocessor to the dev split, then to the test split.\n",
    "valid_pack_processed, predict_pack_processed = (\n",
    "    preprocessor.transform(pack) for pack in (valid_pack, predict_pack)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking task using MatchZoo's RankHingeLoss.\n",
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "\n",
    "# Metrics, in reporting order: NDCG@3, NDCG@5, then mean average precision.\n",
    "ranking_task.metrics = [\n",
    "    *(mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff) for cutoff in (3, 5)),\n",
    "    mz.metrics.MeanAveragePrecision(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to ArcI.\n",
      "Parameter \"embedding_trainable\" set to True.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             4963800     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_1 (Conv1D)               (None, 10, 128)      115328      embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_2 (Conv1D)               (None, 100, 128)     115328      embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling1d_1 (MaxPooling1D)  (None, 2, 128)       0           conv1d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling1d_2 (MaxPooling1D)  (None, 25, 128)      0           conv1d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 256)          0           max_pooling1d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)             (None, 3200)         0           max_pooling1d_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 3456)         0           flatten_1[0][0]                  \n",
      "                                                                 flatten_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 3456)         0           concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 100)          345700      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 1)            101         dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 1)            2           dense_2[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 5,540,259\n",
      "Trainable params: 5,540,259\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.ArcI()\n",
    "\n",
    "# ArcI hyper-parameters, listed in the order they are assigned to model.params.\n",
    "arci_params = [\n",
    "    # Inputs, task and embedding size come from the fitted preprocessor / task above.\n",
    "    ('input_shapes', preprocessor.context['input_shapes']),\n",
    "    ('task', ranking_task),\n",
    "    ('embedding_input_dim', preprocessor.context['vocab_size']),\n",
    "    ('embedding_output_dim', 300),\n",
    "    # One convolution block per side: 128 filters, kernel size 3, pool size 4.\n",
    "    ('num_blocks', 1),\n",
    "    ('left_filters', [128]),\n",
    "    ('left_kernel_sizes', [3]),\n",
    "    ('left_pool_sizes', [4]),\n",
    "    ('right_filters', [128]),\n",
    "    ('right_kernel_sizes', [3]),\n",
    "    ('right_pool_sizes', [4]),\n",
    "    ('conv_activation_func', 'relu'),\n",
    "    # MLP head, dropout and optimizer.\n",
    "    ('mlp_num_layers', 1),\n",
    "    ('mlp_num_units', 100),\n",
    "    ('mlp_num_fan_out', 1),\n",
    "    ('mlp_activation_func', 'relu'),\n",
    "    ('dropout_rate', 0.9),\n",
    "    ('optimizer', 'adadelta'),\n",
    "]\n",
    "for param_name, param_value in arci_params:\n",
    "    model.params[param_name] = param_value\n",
    "\n",
    "# Fill whatever was not set explicitly, then build, compile and print the backend summary.\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 300-dimensional GloVe vectors (same size as the model's embedding_output_dim).\n",
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)\n",
    "\n",
    "# Build the matrix from the term index stored by the fitted preprocessor's vocabulary unit.\n",
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the matrix built from GloVe to the model through MatchZoo's load_embedding_matrix.\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split, unpacked and kept aside for a final evaluation once training is finished.\n",
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "\n",
    "# Per-epoch monitoring uses the dev split, so the test split does not steer training decisions.\n",
    "valid_x, valid_y = valid_pack_processed[:].unpack()\n",
    "\n",
    "# Size the batch by the label array. The unpacked inputs appear to be a dict of arrays\n",
    "# (assumption - confirm against DataPack.unpack), in which case len() on them would count\n",
    "# input fields instead of examples and evaluation would run in tiny batches.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(\n",
    "    model,\n",
    "    x=valid_x,\n",
    "    y=valid_y,\n",
    "    batch_size=len(valid_y),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair generator over the processed training pack.\n",
    "train_generator = mz.PairDataGenerator(\n",
    "    train_pack_processed,\n",
    "    num_dup=2,\n",
    "    num_neg=1,\n",
    "    batch_size=20,\n",
    ")\n",
    "\n",
    "# Display the generator length (it matches the steps per epoch shown during training).\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 11s 107ms/step - loss: 0.8801\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.49344952173663026 - normalized_discounted_cumulative_gain@5(0): 0.5481885983527495 - mean_average_precision(0): 0.5076126965212019\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.7368\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5102411063977357 - normalized_discounted_cumulative_gain@5(0): 0.5761653660625962 - mean_average_precision(0): 0.5331800108222349\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.6898\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5307537256185866 - normalized_discounted_cumulative_gain@5(0): 0.595153974435387 - mean_average_precision(0): 0.5550941013705694\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.6072\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5392971214856643 - normalized_discounted_cumulative_gain@5(0): 0.6017646212562336 - mean_average_precision(0): 0.5581923086473634\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 9s 93ms/step - loss: 0.5444\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.545570586358596 - normalized_discounted_cumulative_gain@5(0): 0.6049780109639101 - mean_average_precision(0): 0.5684973963288421\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 10s 95ms/step - loss: 0.5070\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5347786399190819 - normalized_discounted_cumulative_gain@5(0): 0.6054129486520806 - mean_average_precision(0): 0.568100040489281\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.4304\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5322553843756904 - normalized_discounted_cumulative_gain@5(0): 0.6065853877261808 - mean_average_precision(0): 0.5629468813776334\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 10s 99ms/step - loss: 0.3982\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5324907557092438 - normalized_discounted_cumulative_gain@5(0): 0.6060483502272122 - mean_average_precision(0): 0.5613403049499233\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 10s 96ms/step - loss: 0.3428\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5444374156960796 - normalized_discounted_cumulative_gain@5(0): 0.6030120572247202 - mean_average_precision(0): 0.5657362977746447\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 9s 92ms/step - loss: 0.3342\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.526410405567667 - normalized_discounted_cumulative_gain@5(0): 0.5954696042181865 - mean_average_precision(0): 0.5512490386671325\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 9s 93ms/step - loss: 0.3037\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5384502625199 - normalized_discounted_cumulative_gain@5(0): 0.6088358522925501 - mean_average_precision(0): 0.5645402879366022\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 10s 99ms/step - loss: 0.2727\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5251405635498017 - normalized_discounted_cumulative_gain@5(0): 0.5946647653920929 - mean_average_precision(0): 0.5520451968813845\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.2461\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5354710367769009 - normalized_discounted_cumulative_gain@5(0): 0.5960505583502559 - mean_average_precision(0): 0.5547951485237411\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.2427\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5400677801460133 - normalized_discounted_cumulative_gain@5(0): 0.6026908143535721 - mean_average_precision(0): 0.5616651152410645\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.2244\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5419251661783024 - normalized_discounted_cumulative_gain@5(0): 0.6016534404162774 - mean_average_precision(0): 0.5621628477008224\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 10s 96ms/step - loss: 0.2043\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5407407941528163 - normalized_discounted_cumulative_gain@5(0): 0.6007534576258129 - mean_average_precision(0): 0.5617861729291366\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.1900\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5301216867320792 - normalized_discounted_cumulative_gain@5(0): 0.5940089336864689 - mean_average_precision(0): 0.550788298323667\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.1981\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5311515472651464 - normalized_discounted_cumulative_gain@5(0): 0.6026924614354999 - mean_average_precision(0): 0.5596961705319858\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.1582\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5314472271855374 - normalized_discounted_cumulative_gain@5(0): 0.6044285862505763 - mean_average_precision(0): 0.5614493400066743\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.1806\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.538068318282243 - normalized_discounted_cumulative_gain@5(0): 0.6021024437971346 - mean_average_precision(0): 0.5638790908355317\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.1573\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5369749686464168 - normalized_discounted_cumulative_gain@5(0): 0.6062313822902646 - mean_average_precision(0): 0.5652298609390939\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.1538\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5323565199294026 - normalized_discounted_cumulative_gain@5(0): 0.6003248369228174 - mean_average_precision(0): 0.5574479069119557\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.1539\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5315018209119983 - normalized_discounted_cumulative_gain@5(0): 0.5975992710575344 - mean_average_precision(0): 0.5547110612561097\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.1237\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.526002357678433 - normalized_discounted_cumulative_gain@5(0): 0.5972283202335096 - mean_average_precision(0): 0.5522531874374763\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.1351\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5330796799745176 - normalized_discounted_cumulative_gain@5(0): 0.5989011418663146 - mean_average_precision(0): 0.5546921154350174\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.1109\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5184401553259078 - normalized_discounted_cumulative_gain@5(0): 0.5881541293733173 - mean_average_precision(0): 0.5440734053089542\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.0978\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5126531626078366 - normalized_discounted_cumulative_gain@5(0): 0.5914728905036133 - mean_average_precision(0): 0.5423304108444408\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.1080\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.517827400521469 - normalized_discounted_cumulative_gain@5(0): 0.586751401222473 - mean_average_precision(0): 0.5441482116928879\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.0989\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5203803639274984 - normalized_discounted_cumulative_gain@5(0): 0.5889616724308256 - mean_average_precision(0): 0.5456447313106599\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.0948\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5218079967226515 - normalized_discounted_cumulative_gain@5(0): 0.5880608250425724 - mean_average_precision(0): 0.544506031754239\n"
     ]
    }
   ],
   "source": [
    "# Train for 30 epochs; the evaluate callback prints the ranking metrics after every epoch.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=30,\n",
    "    callbacks=[evaluate],\n",
    "    workers=30,\n",
    "    use_multiprocessing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
